/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file flash_attention_score_grad_tiling_s1s2_bn2.cc
 * \brief
 */

#include "flash_attention_score_grad_tiling_s1s2_bn2.h"
#include "tiling/tiling_templates_registry.h"

namespace optiling {

// ---------------- Tiling constants ----------------
constexpr uint32_t DIMS_FOUR = 4;
constexpr uint32_t DIMS_THREE = 3;
constexpr uint32_t HALF_PRECISION_SIZE = 2;   // bytes per fp16/bf16 element
constexpr uint32_t FLOAT_PRECISION_SIZE = 4;  // bytes per fp32 element
constexpr uint32_t MASK_FLOAT_VALUE = 64;     // vector mask length for fp32
constexpr uint32_t MASK_HALF_VALUE = 128;     // vector mask length for fp16
constexpr uint32_t USED_UB_BLOCK = 5;
// NOTE(review): value is 32 MB but the original comment said "100MB system
// reserve" (100MB系统预留) — confirm which size is intended.
constexpr uint32_t WORK_SPACE_BASE_CAL = 32 * 1024 * 1024;
constexpr uint32_t BLOCK = 32;                               // 32B
constexpr uint32_t B32 = 4;                                  // 4B
constexpr uint32_t B16 = 2;                                  // 2B per 16-bit element
constexpr uint32_t HIGH_PRECISION_0 = 0;                     // calcMode: high precision
constexpr uint32_t HIGH_PERFORMANCE_1 = 1;                   // calcMode: high performance
constexpr uint32_t BASE_LEN_128 = 128;
constexpr uint32_t BASE_LEN_256 = 256;
constexpr uint32_t BASE_LEN_512 = 512;
constexpr uint32_t BASE_LEN_64 = 64;
constexpr uint32_t BASE_LEN_32 = 32;
constexpr uint32_t INPUT_BUFFER_LEN = 60 * 1024;             // UB input buffer bytes
constexpr uint32_t OUTPUT_BUFFER_LEN = 30 * 1024;            // UB output buffer bytes
constexpr uint32_t NUMBER_2 = 2;
constexpr uint32_t BITNUM = 8;                               // bits per byte
constexpr uint32_t GM_ALIGN = 512;                           // global-memory alignment, bytes
constexpr uint32_t MAX_STRIDE_LIMIT = 65535;                 // hardware stride upper bound
constexpr uint32_t MM_RATIO_4 = 4;                           // default S1*S2 ratio budget
constexpr uint32_t PSE_ALIBI_S1_SIZE = 1024;                 // fixed H / min S1 for alibi pse
constexpr uint32_t MATMUL_MINI_N = 16;                       // minimal matmul N alignment
constexpr uint32_t VEC_REPEAT_LIMIT = 255;                   // max vector-instruction repeats
constexpr uint32_t BAND_MODE = 3;                            // atten-mask compress mode for BAND
constexpr uint32_t POST_COEX_NODE = 3;
constexpr uint32_t BUFFER_NUM = 1;
// 5-D shapes [B, N, S1, S2, D] that the module-1 template handles.
const std::vector<gert::Shape> MODULE1_TEMPLATE_SUPPORTED_SHAPE = {
    // 5d shape [B, N,S1,S2,D]
    {24, 5, 9216, 9216, 64},
    {24, 5, 9216, 80, 64},
    {24, 10, 2304, 2304, 64},
};

// Kernel branch ids keyed by input dtype / precision mode.
enum KernelBranch { BRANCH_FP16_HIGH_PRECISION = 10001, BRANCH_FP16_HIGH_PERFORMANCE, BRANCH_BF16, BRANCH_FP32 };

bool FlashAttentionScoreGradTilingUs1s2Bbn2::IsCapable() {
  // This template needs at least one (B, N2) pair per vector core.
  const int64_t bn2Count = td_.opInfo.get_B() * td_.opInfo.get_N2();
  if (bn2Count >= static_cast<int64_t>(aicoreParams_.blockDim)) {
    return true;
  }
  OPS_LOG_I(context_,
            "FlashAttentionScoreGradTilingUs1s2Bbn2 not support this case: nNums: %ld < coreNums: %lu",
            bn2Count, aicoreParams_.blockDim);
  return false;
}

bool FlashAttentionScoreGradTilingDeterministic::IsCapable() {
  // The deterministic template is selected only when the framework requests
  // deterministic computation (flag == 1).
  OPS_LOG_D(context_, "Get deterministic flag is %d", context_->GetDeterministic());
  return context_->GetDeterministic() == 1;
}

uint64_t FlashAttentionScoreGradTilingUs1s2Bbn2::GetTilingKey() const {
  bool is_bsh = tmpData_.layout == static_cast<uint32_t>(InputLayout::BSH) ||
                tmpData_.layout == static_cast<uint32_t>(InputLayout::BSND);
  bool is_sbh = tmpData_.layout == static_cast<uint32_t>(InputLayout::SBH);
  bool is_bnsd = tmpData_.layout == static_cast<uint32_t>(InputLayout::BNSD);
  bool is_tnd = tmpData_.layout == static_cast<uint32_t>(InputLayout::TND);

  uint64_t tilingKey = 0;
  CubeFormatEnum mmOutFormat = CubeFormatEnum::ND;
  MatmulConfig mm_config = MatmulConfig::NULL_CONFIG;
  DtypeEnum dtype_in = DtypeEnum::FLOAT16;
  LayoutEnum layout = LayoutEnum::BNSD;
  // current only BNSD
  if ((tmpData_.queryType == ge::DT_BF16)) {
    dtype_in = DtypeEnum::BFLOAT16;
  }
  if ((tmpData_.queryType == ge::DT_FLOAT)) {
    dtype_in = DtypeEnum::FLOAT32;
  }
  if ((tmpData_.s2 % (GM_ALIGN / B32)) != 0) {
    // 非cacheline对齐场景使能NZ优化
    mmOutFormat = CubeFormatEnum::NZ;
  }
  if (is_bsh) {
    layout = LayoutEnum::BSND;
  } else if (is_sbh) {
    if (tmpData_.b * tmpData_.n2 * tmpData_.d > MAX_STRIDE_LIMIT) {
      // SBH: B*H > 65535
      mm_config = MatmulConfig::NORMAL_CONFIG;
    }
    layout = LayoutEnum::SBND;
  } else if (is_bnsd) {
    layout = LayoutEnum::BNSD;
  } else if (is_tnd) {
    layout = LayoutEnum::TND;
    mmOutFormat = CubeFormatEnum::ND;
  }
  tilingKey = GET_TILINGKEY(AxisEnum::S2, AxisEnum::S1, AxisEnum::N2, dtype_in, layout, SparseEnum::ALL,
                            mm_config, mmOutFormat, tmpData_.pse_cfg, tmpData_.atten_mask_cfg, tmpData_.drop_out_cfg);

  OPS_LOG_I(context_, "FAGTiling Us1s2Bbn2 DoTiling success, tilingkey is %lu.", tilingKey);
  return tilingKey;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::GetPlatformInfo() {
  // Fill aicoreParams_ from live platform info when available; otherwise fall
  // back to the pre-parsed compile info (e.g. offline tiling).
  auto platformInfoPtr = context_->GetPlatformInfo();
  if (platformInfoPtr != nullptr) {
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfoPtr);
    aicoreParams_.blockDim = ascendcPlatform.GetCoreNumAiv();
    aicoreParams_.aicNum = ascendcPlatform.GetCoreNumAic();
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, aicoreParams_.ubSize);
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L1, aicoreParams_.l1Size);
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L0_A, aicoreParams_.l0aSize);
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, aicoreParams_.l0bSize);
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, aicoreParams_.l0cSize);
    return ge::GRAPH_SUCCESS;
  }

  auto compileInfoPtr = reinterpret_cast<const FlashAttentionScoreGradCompileInfo*>(context_->GetCompileInfo());
  OPS_ERR_IF(compileInfoPtr == nullptr, OPS_REPORT_CUBE_INNER_ERR(context_, "compile_info is null"),
             return ge::GRAPH_FAILED);
  aicoreParams_.blockDim = compileInfoPtr->aivNum;
  aicoreParams_.aicNum = compileInfoPtr->aicNum;
  aicoreParams_.ubSize = compileInfoPtr->ubSize;
  aicoreParams_.l1Size = compileInfoPtr->l1Size;
  aicoreParams_.l0aSize = compileInfoPtr->l0aSize;
  aicoreParams_.l0bSize = compileInfoPtr->l0bSize;
  aicoreParams_.l0cSize = compileInfoPtr->l0cSize;
  return ge::GRAPH_SUCCESS;
}

void FlashAttentionScoreGradTilingUs1s2Bbn2::DecideBranch() {
  // Select the kernel branch and per-element params from the query dtype.
  // Defaults below correspond to the fp32 branch.
  tmpData_.branch = BRANCH_FP32;
  tmpData_.dataTypeSize = FLOAT_PRECISION_SIZE;
  tmpData_.mask = MASK_FLOAT_VALUE;
  tmpData_.queryType = static_cast<uint32_t>(context_->GetInputDesc(QUERY)->GetDataType());

  // High precision is forced on by requirement, 2023/9/26 by zhangguangjun.
  tmpData_.calcMode = HIGH_PRECISION_0;
  OPS_LOG_D(context_, "calcMode is %d", tmpData_.calcMode);

  if (tmpData_.queryType == ge::DT_BF16) {
    tmpData_.branch = BRANCH_BF16;
    return;
  }
  if (tmpData_.queryType != ge::DT_FLOAT16) {
    return;  // fp32: keep the defaults set above
  }
  // fp16: pick the precision/performance variant.
  if (tmpData_.calcMode == HIGH_PERFORMANCE_1) {
    tmpData_.branch = BRANCH_FP16_HIGH_PERFORMANCE;
    tmpData_.dataTypeSize = HALF_PRECISION_SIZE;
    tmpData_.mask = MASK_HALF_VALUE;
  } else {
    tmpData_.branch = BRANCH_FP16_HIGH_PRECISION;
  }
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::SetMaskShapeType(
    const gert::Shape& storageShape, const uint32_t maskShapeDims) {
  // Classify the atten-mask layout from its rank and leading dims.
  if (maskShapeDims == DIM_2) {
    // 2-D masks are always [S1, S2].
    td_.opInfo.set_maskShapeType(ATTEN_MASK_SHAPE_TYPE_SS);
    return ge::GRAPH_SUCCESS;
  }
  if (maskShapeDims != DIM_4) {
    OPS_LOG_E(context_, "FlashAttentionScoreGradTilingUs1s2Bbn2 not support attenMaskShape dim num: %d", maskShapeDims);
    return ge::GRAPH_FAILED;
  }

  const auto dim0 = storageShape.GetDim(DIM_0);
  const auto dim1 = storageShape.GetDim(DIM_1);
  const bool batchMatches = dim0 == td_.opInfo.get_B();
  if (batchMatches && dim1 == td_.opInfo.get_N2() * td_.opInfo.get_G()) {
    td_.opInfo.set_maskShapeType(ATTEN_MASK_SHAPE_TYPE_BNSS);  // [B, N1, S1, S2]
  } else if (batchMatches && dim1 == 1) {
    td_.opInfo.set_maskShapeType(ATTEN_MASK_SHAPE_TYPE_B1SS);  // [B, 1, S1, S2]
  } else if (dim0 == 1 && dim1 == 1) {
    td_.opInfo.set_maskShapeType(ATTEN_MASK_SHAPE_TYPE_SS);    // [1, 1, S1, S2]
  } else {
    OPS_LOG_E(context_,
              "FlashAttentionScoreGradTilingUs1s2Bbn2 not support attenMaskShape dim0: %ld, dim1: %ld", dim0, dim1);
    return ge::GRAPH_FAILED;
  }
  return ge::GRAPH_SUCCESS;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::GetAttenMaskInfo() {
  /*
  Parse the optional atten-mask input: presence flag, shape class, S1/S2
  sizes, compress mode derived from sparseMode, and mask dtype category.
  Defaults are "no mask / no compression" when the input is absent or empty.
  */
  td_.opInfo.set_existAttenMask(0);
  attenMaskCompressMode = NO_COMPRESS_MODE;
  auto attenMaskShape = context_->GetOptionalInputShape(ATTEN_MASK);
  const bool maskPresent = attenMaskShape != nullptr && attenMaskShape->GetStorageShape().GetDimNum() != 0;
  if (maskPresent) {
    // attenMask supports [S1,S2](0), [B,1,S1,S2](1) and [B,N1,S1,S2](2).
    td_.opInfo.set_existAttenMask(1);
    tmpData_.atten_mask_cfg = AttenMaskConfig::EXIST_ATTEN_MASK;
    auto storageShape = attenMaskShape->GetStorageShape();
    auto maskShapeDims = storageShape.GetDimNum();
    if (SetMaskShapeType(storageShape, maskShapeDims) != ge::GRAPH_SUCCESS) {
      return ge::GRAPH_FAILED;
    }

    // Trailing two axes of the mask are [S1, S2].
    attenMaskS2Size = storageShape.GetDim(maskShapeDims - LAST_AXIS_IDX);
    attenMaskS1Size = storageShape.GetDim(maskShapeDims - SEC_LAST_AXIS_IDX);

    // The sparse mode decides which compressed-mask layout the kernel expects.
    if (sparseMode == LEFT_UP_CAUSAL) {
      attenMaskCompressMode = LEFT_UP_CAUSAL_MODE;
    } else if (sparseMode == RIGHT_DOWN_CAUSAL) {
      attenMaskCompressMode = RIGHT_DOWN_CAUSAL_MODE;
    } else if (sparseMode == BAND) {
      attenMaskCompressMode = BAND_MODE;
    } else if (sparseMode == PREFIX_COMPRESS) {
      attenMaskCompressMode = PREFIX_COMPRESS_MODE;
    }

    auto attenMaskDesc = context_->GetOptionalInputDesc(ATTEN_MASK);
    if (attenMaskDesc != nullptr) {
      const auto maskDtype = attenMaskDesc->GetDataType();
      if (maskDtype == tmpData_.queryType) {
        td_.opInfo.set_maskDataType(ATTEN_MASK_TYPE_SAME);
      } else if (maskDtype == ge::DT_BOOL || maskDtype == ge::DT_UINT8) {
        td_.opInfo.set_maskDataType(ATTEN_MASK_TYPE_U8_BOOL);
      } else {
        OPS_LOG_E(context_, "FlashAttentionScoreGradTilingUs1s2Bbn2 not support attenMaskDtype: %d", maskDtype);
        return ge::GRAPH_FAILED;
      }
    }
  }
  td_.opInfo.set_attenMaskS2Size(attenMaskS2Size);
  td_.opInfo.set_attenMaskCompressMode(attenMaskCompressMode);
  return ge::GRAPH_SUCCESS;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::GetPseInfo() {
  /*
  Get pse (positional bias) input info: classify the optional 4-D pse shape
  into one of the supported layouts and record the matching shape type in
  tiling data. Fails for unsupported shapes or dim counts.
  */
  auto pseShape = context_->GetOptionalInputShape(PSE_SHIFT);
  if (pseShape != nullptr && pseShape->GetStorageShape().GetDimNum() != 0) {
    // pse support [B N S1 S2](0) + [B N 1 S2](1) + [1 N S1 S2](2)
    // pse support alibi: [B N H S2](3) + [1 N H S2](4)
    tmpData_.pse_cfg = PseConfig::EXIST_PSE;
    auto storageShape = pseShape->GetStorageShape();
    auto pseShapeDims = storageShape.GetDimNum();
    if (pseShapeDims == DIM_4) {
      auto dim0 = storageShape.GetDim(DIM_0);
      auto dim1 = storageShape.GetDim(DIM_1);
      auto dim2 = storageShape.GetDim(DIM_2);
      auto dim3 = storageShape.GetDim(DIM_3);
      // Shape checks for [B N S1 S2](0), [B N 1 S2](1), [1 N S1 S2](2)
      uint32_t shapeN1 = td_.opInfo.get_N2() * td_.opInfo.get_G();
      bool isBNS = (dim0 == td_.opInfo.get_B()) && (dim1 == shapeN1) && (dim3 == td_.opInfo.get_S2());
      bool isBNSS = isBNS && (dim2 == td_.opInfo.get_S1());
      bool isBN1S = isBNS && (dim2 == 1);
      bool is1NSS = (dim0 == 1) && (dim1 == shapeN1) && (dim2 == td_.opInfo.get_S1()) && (dim3 == td_.opInfo.get_S2());
      // Shape checks for [B N H S2](4), [1 N H S2](5); H is fixed at 1024
      bool isNHS = (dim1 == shapeN1) && (dim2 == PSE_ALIBI_S1_SIZE) && (dim3 == td_.opInfo.get_S2());
      // alibi condition 1: S1 == S2 and S1 >= 1024
      // alibi condition 2: lower-triangular (causal) token window
      bool alibiCondi1 = (td_.opInfo.get_S1() == td_.opInfo.get_S2()) && (td_.opInfo.get_S1() >= PSE_ALIBI_S1_SIZE);
      bool alibiCondi2 = (td_.opInfo.get_preTokens() >= int32_t(td_.opInfo.get_S1())) && (td_.opInfo.get_nextTokens() == 0);
      bool isAlibiBNHS = isNHS && (dim0 == td_.opInfo.get_B()) && alibiCondi1 && alibiCondi2;
      bool isAlibi1NHS = isNHS && (dim0 == 1) && alibiCondi1 && alibiCondi2;
      bool isUnpad = td_.opInfo.get_layout() == static_cast<uint32_t>(InputLayout::TND);
      // Record the detected shape type; unpad (TND) skips the alibiCondi1
      // check since sequence lengths are per-batch there.
      if (isUnpad && isNHS && (dim0 == 1) && alibiCondi2) {
        td_.opInfo.set_pseShapeType(PSE_SHAPE_TYPE_1NHS);
      } else if (isUnpad && isNHS && (dim0 == td_.opInfo.get_B()) && alibiCondi2) {
        td_.opInfo.set_pseShapeType(PSE_SHAPE_TYPE_BNHS);
      } else if (is1NSS) {
        td_.opInfo.set_pseShapeType(PSE_SHAPE_TYPE_1NSS);
      } else if (isBN1S) {
        td_.opInfo.set_pseShapeType(PSE_SHAPE_TYPE_BN1S);
      } else if (isBNSS) {
        td_.opInfo.set_pseShapeType(PSE_SHAPE_TYPE_BNSS);
      } else if (isAlibi1NHS) {
        td_.opInfo.set_pseShapeType(PSE_SHAPE_TYPE_1NHS);
      } else if (isAlibiBNHS) {
        td_.opInfo.set_pseShapeType(PSE_SHAPE_TYPE_BNHS);
      } else {
        OPS_LOG_E(context_, "The shape of pse[%ld,%ld,%ld,%ld] is invalid or tocken[%d,%d] not casual", dim0, dim1,
                  dim2, dim3, td_.opInfo.get_preTokens(), td_.opInfo.get_nextTokens());
        return ge::GRAPH_FAILED;
      }
    } else {
      OPS_LOG_E(context_, "FlashAttentionScoreGradTilingUs1s2Bbn2 not support pseShape dim num: %zu", pseShapeDims);
      return ge::GRAPH_FAILED;
    }
  }
  return ge::GRAPH_SUCCESS;
}

bool FlashAttentionScoreGradTilingUs1s2Bbn2::SparseTokenProcess() {
    // Validate sparseMode against the pre/next token windows.
    // Returns true when sparse processing should continue (prefix handling,
    // then enabling isSparse in the caller). Returns false to stop early —
    // note that in the NO_MASK branch isSparse may already have been set
    // before returning false; NO_MASK never proceeds to prefix handling.
    int32_t preTokens = td_.opInfo.get_preTokens();
    int32_t nextTokens = td_.opInfo.get_nextTokens();
    OPS_LOG_D(context_, "preTokens: %d, nextTokens: %d, sparseMode: %d.", preTokens, nextTokens, sparseMode);

    if (sparseMode == NO_MASK) {
        if (int32_t(td_.opInfo.get_S1()) <= preTokens && nextTokens == 0) { // legacy causal case
            td_.opInfo.set_isSparse(1);
        } else if (int32_t(td_.opInfo.get_S1()) > preTokens || int32_t(td_.opInfo.get_S2()) > nextTokens) { // legacy band case
            td_.opInfo.set_isSparse(1);
        } else {
            OPS_LOG_I(context_,
                      "FAG Us1s2Bbn2 sparseMode[%d] with s1[%u] s2[%u] preToken[%d] nextToken[%d] is not support!",
                      sparseMode, td_.opInfo.get_S1(), td_.opInfo.get_S2(), preTokens, nextTokens);
        }
        return false;
    }

    // Any other mode must be one of the explicitly supported sparse modes.
    if (sparseMode != LEFT_UP_CAUSAL && sparseMode != RIGHT_DOWN_CAUSAL && sparseMode != BAND &&
        sparseMode != PREFIX && sparseMode != PREFIX_COMPRESS && sparseMode != RIGHT_DOWN_CASUAL_BAND &&
        sparseMode != BAND_LEFT_UP_CASUAL) {
      OPS_LOG_W(context_, "FAG Us1s2Bbn2 sparseMode[%d] is not support!", sparseMode);
      return false;
    }

    if (sparseMode == BAND) {
        // In BAND mode the tokens are measured from the bottom-right corner;
        // a window covering the whole matrix is rejected as invalid.
        if (int32_t(td_.opInfo.get_S2()) <= preTokens && int32_t(td_.opInfo.get_S1()) <= nextTokens) {
            OPS_LOG_W(context_,
                      "FAG Us1s2Bbn2 sparseMode is band, but preToken[%d] or nextToken[%d] is invalid "
                      "with s1[%u] s2[%u]!",
                      preTokens, nextTokens, td_.opInfo.get_S1(), td_.opInfo.get_S2());
            return false;
        }
    }
    return true;
}

bool FlashAttentionScoreGradTilingUs1s2Bbn2::ProcessPrefix() {
    // For the prefix sparse modes, read and validate the prefixN tensor
    // (one entry per batch) and copy it into tiling data.
    // Returns true when prefix handling succeeded or is not needed.
    if (sparseMode != PREFIX && sparseMode != PREFIX_COMPRESS) {
        return true;  // non-prefix modes: nothing to do
    }

    auto prefixNTensor = context_->GetOptionalInputTensor(PREFIX_N);
    if (prefixNTensor == nullptr) {
        OPS_LOG_W(context_, "FAG Us1s2Bbn2 sparseMode is prefix, but prefixN tensor is null!");
        return false;
    }

    // prefixN must be a 1-D tensor of length B.
    auto &prefixShape = prefixNTensor->GetShape().GetStorageShape();
    if (prefixShape.GetDimNum() != 1 || prefixShape.GetDim(0) != td_.opInfo.get_B()) {
        OPS_LOG_W(context_,
                  "FAG Us1s2Bbn2 sparseMode is prefix, but prefixshape size[%zu] or value is invalid!",
                  prefixShape.GetDimNum());
        return false;
    }

    std::vector<int64_t> prefixN;
    const int64_t* value = prefixNTensor->GetData<int64_t>();
    if (value == nullptr) {
        OPS_LOG_W(context_, "FAG Us1s2Bbn2 sparseMode is prefix, but prefixN data is null pointer!");
        return false;
    }
    auto shapeSize = prefixNTensor->GetShapeSize();
    for (auto i = 0; i < shapeSize; i++) {
        prefixN.push_back(value[i]);
    }

    if (prefixN.size() == td_.opInfo.get_B() && prefixN.size() < BATCH_MAX_SIZE) {
        uint32_t tempPrefixN[BATCH_MAX_SIZE];
        // NOTE(review): entries are narrowed int64 -> uint32 here; assumes
        // every prefixN value fits in uint32 — confirm upstream validation.
        std::copy(prefixN.begin(), prefixN.end(), tempPrefixN);
        td_.opInfo.set_prefixN(tempPrefixN);
    } else {
        OPS_LOG_W(context_, "FAG Us1s2Bbn2 sparseMode is prefix, but prefixN value is invalid!");
        return false;
    }

    return true;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::SetSparseParams() {
    // Decide whether the sparse (token-window) path can be enabled and record
    // sparse-related fields in tiling data. Unsupported combinations fall
    // back to non-sparse (GRAPH_SUCCESS with isSparse == 0); only
    // prefix-compress mode, which requires sparse support, can fail here.
    td_.opInfo.set_isSparse(0);
    td_.opInfo.set_sparseMode(sparseMode);
    auto attenMaskShape = context_->GetOptionalInputShape(ATTEN_MASK);
    auto attenMaskDesc = context_->GetOptionalInputDesc(ATTEN_MASK);
    if (sparseMode == ALL_MASK || attenMaskShape == nullptr || attenMaskDesc == nullptr) {
      OPS_LOG_I(context_,
                "FAG Us1s2Bbn2 sparseMode[%d] is all_mask or attenmask is nullptr,"
                "not support sparse mode!", sparseMode);
      return ge::GRAPH_SUCCESS;
    }

    // Token-window validation; false means stay on the non-sparse path.
    if (!SparseTokenProcess()) {
      return ge::GRAPH_SUCCESS;
    }

    // Prefix modes additionally need a valid prefixN tensor.
    if (!ProcessPrefix()) {
      if (sparseMode != PREFIX_COMPRESS) {
        return ge::GRAPH_SUCCESS;
      }

      OPS_LOG_E(context_, "Sparse capability must be supported under prefix compress mode, pls check input params");
      return ge::GRAPH_FAILED;
    }

    td_.opInfo.set_isSparse(1);
    return ge::GRAPH_SUCCESS;
}

void FlashAttentionScoreGradTilingUs1s2Bbn2::SetBandIdx() {
  // Record which batch index carries the band part for the mixed
  // causal+band sparse modes; batches with actualSeqQlen == 0 are skipped.
  if (sparseMode == RIGHT_DOWN_CASUAL_BAND) {
    // Band belongs to the last non-empty batch.
    // Use int64_t so the index cannot wrap when tmpData_.b == 0 (the original
    // `int i = tmpData_.b - 1` underflowed in unsigned arithmetic first).
    for (int64_t i = static_cast<int64_t>(tmpData_.b) - 1; i >= 0; i--) {
      if (tmpData_.actualSeqQlen[i] != 0) {
        td_.opInfo.set_bandIdx(i);
        return;
      }
    }
  } else if (sparseMode == BAND_LEFT_UP_CASUAL) {
    // Band belongs to the first non-empty batch.
    for (size_t i = 0; i < tmpData_.b; i++) {
      if (tmpData_.actualSeqQlen[i] != 0) {
        td_.opInfo.set_bandIdx(i);
        return;
      }
    }
  }
  // Fallback: no band batch found (or other sparse modes) — default to 0.
  td_.opInfo.set_bandIdx(0);
}

bool FlashAttentionScoreGradTilingUs1s2Bbn2::IsModuleOneShape() {
  // True when the 5-D shape [B, N1, S1, S2, D] is one of the shapes the
  // module-1 template handles. (Double negation `!(… == end)` replaced with
  // the direct `!= end` comparison.)
  gert::Shape input_shape_5d = {td_.opInfo.get_B(), td_.opInfo.get_N2() * td_.opInfo.get_G(), td_.opInfo.get_S1(),
                                td_.opInfo.get_S2(), td_.opInfo.get_D()};
  return std::find(MODULE1_TEMPLATE_SUPPORTED_SHAPE.begin(), MODULE1_TEMPLATE_SUPPORTED_SHAPE.end(),
                   input_shape_5d) != MODULE1_TEMPLATE_SUPPORTED_SHAPE.end();
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::CheckOutOfTokens(const uint32_t s1, const uint32_t s2) {
  // Validate the pre/next token windows against (possibly per-batch) S1/S2.
  // For NO_MASK the negative-token bounds are taken on the swapped axes.
  const int32_t preTokens = td_.opInfo.get_preTokens();
  const int32_t nextTokens = td_.opInfo.get_nextTokens();
  const int32_t negPretMax = sparseMode == NO_MASK ? s2 : s1;
  const int32_t negNexttMax = sparseMode == NO_MASK ? s1 : s2;
  const bool preOutOfRange = -preTokens > negPretMax;
  const bool nextOutOfRange = -nextTokens > negNexttMax;
  const bool emptyWindow = (int64_t(preTokens) + int64_t(nextTokens)) < 0;
  if (preOutOfRange || nextOutOfRange || emptyWindow) {
    OPS_LOG_E(context_, "FAG Us1s2Bbn2 with s1[%u] s2[%u] preToken[%d] nextToken[%d] is not support!",
              s1, s2, preTokens, nextTokens);
    return ge::GRAPH_FAILED;
  }
  return ge::GRAPH_SUCCESS;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::CheckTokens() {
  // Validate pre/next tokens for the sparse modes that use them. For TND
  // layout the check runs against the actual per-batch sequence lengths;
  // otherwise against the global s1/s2.
  bool enableTokens = sparseMode == NO_MASK || sparseMode == BAND || sparseMode == RIGHT_DOWN_CASUAL_BAND ||
                      sparseMode == BAND_LEFT_UP_CASUAL;
  if (!enableTokens) {
    return ge::GRAPH_SUCCESS;  // tokens are ignored in the remaining modes
  }

  uint32_t realS1 = tmpData_.s1;
  uint32_t realS2 = tmpData_.s2;
  if (td_.opInfo.get_layout() == static_cast<uint32_t>(InputLayout::TND)) {
    if (sparseMode == NO_MASK || sparseMode == BAND) {
      // Tokens apply to every batch: check each (S1, S2) pair.
      for (size_t i = 0; i < tmpData_.b; i++) {
        realS1 = tmpData_.actualSeqQlen[i];
        realS2 = tmpData_.actualSeqKvlen[i];
        auto status = CheckOutOfTokens(realS1, realS2);
        if (status != ge::GRAPH_SUCCESS) {
          return status;
        }
      }
      return ge::GRAPH_SUCCESS;
    }

    // Mixed causal+band modes: only the band batch (bandIdx) uses the tokens.
    realS1 = tmpData_.actualSeqQlen[td_.opInfo.get_bandIdx()];
    realS2 = tmpData_.actualSeqKvlen[td_.opInfo.get_bandIdx()];
  }
  return CheckOutOfTokens(realS1, realS2);
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::GetShapeAttrsInfo() {
  /*
  Get all shape info and attrs, then run the token / atten-mask / sparse /
  pse validations. Any failing sub-step aborts tiling with its status.
  */
  auto status = GetLayoutInfo();
  if (status != ge::GRAPH_SUCCESS) {
    return status;
  }

  status = GetBaseShapeInfo();
  if (status != ge::GRAPH_SUCCESS) {
    return status;
  }

  DecideBranch();

  // setTilingData: scale / keep-prob attrs; keepProb < 1 implies dropout.
  td_.opInfo.set_scaleValue(*context_->GetAttrs()->GetAttrPointer<float>(SCALE_VALUE));
  td_.opInfo.set_keepProb(*context_->GetAttrs()->GetAttrPointer<float>(KEEP_PROB));
  if (td_.opInfo.get_keepProb() < 1) {
    tmpData_.drop_out_cfg = DropOutConfig::EXIST_DROP_OUT;
  }

  // SPARSE_MODE attr was added later; guard on the attr count so graphs
  // compiled before it existed remain compatible.
  auto attrs = context_->GetAttrs();
  if (attrs->GetAttrNum() > static_cast<size_t>(SPARSE_MODE)) {
    sparseMode = SparseMode(*(attrs->GetAttrPointer<int>(SPARSE_MODE))); // 7
  }

  td_.opInfo.set_preTokens(*context_->GetAttrs()->GetAttrPointer<int>(PRE_TOKENS));
  td_.opInfo.set_nextTokens(*context_->GetAttrs()->GetAttrPointer<int>(NEXT_TOKENS));

  // bandIdx must be set before token checks (TND band modes read it).
  SetBandIdx();
  status = CheckTokens();
  if (status != ge::GRAPH_SUCCESS) {
    return status;
  }

  status = GetAttenMaskInfo();
  if (status != ge::GRAPH_SUCCESS) {
    return status;
  }

  status = SetSparseParams();
  if (status != ge::GRAPH_SUCCESS) {
    return status;
  }

  // Non-sparse module-1 shapes are left to another template unless the
  // deterministic path is requested.
  if ((td_.opInfo.get_isSparse() == 0) && IsModuleOneShape() && context_->GetDeterministic() != 1) {
    OPS_LOG_D(context_, "Sparse disable. FlashAttentionScoreGradTilingUs1s2Bbn2 not support this case");
    return ge::GRAPH_PARAM_INVALID;
  }

  OPS_LOG_D(context_, "FAG Us1s2Bbn2 sparse %s.", td_.opInfo.get_isSparse() == 0 ? "disable" : "enable");

  status = GetPseInfo();
  if (status != ge::GRAPH_SUCCESS) {
    return status;
  }
  return ge::GRAPH_SUCCESS;
}

void FlashAttentionScoreGradTilingUs1s2Bbn2::NMDStrategy() {
  // SBH with B*H > 65535: transfer efficiency is already poor, so keep the
  // default S1/S2 ratios instead of tuning them.
  if (tmpData_.layout == static_cast<uint32_t>(InputLayout::SBH) &&
      tmpData_.b * tmpData_.n2 * tmpData_.d > MAX_STRIDE_LIMIT) {
      return;
  }
  // Spend the ratio budget on S2 first; whatever remains goes to S1,
  // keeping s1Ratio * s2Ratio <= mmRatio.
  s1Ratio = 1;
  s2Ratio = CeilCommon(td_.opInfo.get_S2(), baseNmm);
  if (s2Ratio > mmRatio) {
    s2Ratio = mmRatio;
  }
  s1Ratio = mmRatio / s2Ratio;
  // Shrink the S1 ratio when S1 itself is shorter than the covered span.
  if (s1Ratio * baseMmm > td_.opInfo.get_S1()) {
    s1Ratio = CeilCommon(td_.opInfo.get_S1(), baseMmm);
  }

  // Hard guard: both ratios must stay positive.
  if (s1Ratio == 0) {
    s1Ratio = 1;
  }
  if (s2Ratio == 0) {
    s2Ratio = 1;
  }
  OPS_LOG_I(context_, "FlashAttentionScoreGradTilingUs1s2Bbn2 s1Ratio: %u, s2Ratio: %u", s1Ratio, s2Ratio);
}

void FlashAttentionScoreGradTilingUs1s2Bbn2::VectorBaseMNSplit() {
  // Decide the vector base block (baseM rows x baseN cols).
  // baseN covers dimS2, the last axis in UB, and therefore needs alignment.
  if (dimS2 < baseN && dimS1 > baseM) {
    // Short S2: shrink baseN to a 16-aligned S2 and grow baseM from the
    // freed budget. 16-alignment reduces redundant data; it is corrected to
    // 32-alignment below when dropout/atten-mask are present.
    baseN = AlignData(dimS2, MATMUL_MINI_N);
    baseM = tensorSize / baseN / (MATMUL_MINI_N) * (MATMUL_MINI_N);
    // Keep baseM within the vector instruction repeat limit (255).
    baseM = baseM > VEC_REPEAT_LIMIT ? (BASE_LEN_256 - (MATMUL_MINI_N * 2)) : baseM;
  } else if (dimS2 > baseN && dimS1 < baseM) {
    // Short S1: take it whole and grow baseN with the remaining budget.
    baseM = dimS1;
    baseN = tensorSize / baseM / MATMUL_MINI_N * MATMUL_MINI_N;
  }

  baseM = baseM > BASE_LEN_128 ? BASE_LEN_128 : baseM;  // simple-softmax only had 4KB UB
  baseN = baseN > BASE_LEN_512 ? BASE_LEN_512 : baseN;  // nz-format only support 512

  bool exist_dropout = tmpData_.drop_out_cfg == DropOutConfig::EXIST_DROP_OUT;
  bool exist_attenmask = tmpData_.atten_mask_cfg == AttenMaskConfig::EXIST_ATTEN_MASK;
  if (exist_dropout || exist_attenmask) {
    // dropout and atten-mask only have 8K UB; baseN must align to 32.
    baseM = tensorSize / (CeilCommon(baseN, BLOCK) * BLOCK);
    baseN = CeilCommon(baseN, BLOCK) * BLOCK;
    baseM = baseM > BASE_LEN_128 ? BASE_LEN_128 : baseM;  // simple-softmax only had 4KB UB
    baseN = baseN > BASE_LEN_512 ? BASE_LEN_512 : baseN;  // nz-format only support 512
  }
  // (redundant trailing `return;` removed)
}

void FlashAttentionScoreGradTilingUs1s2Bbn2::SFTBaseMDSplit() {
  // Decide the softmax-grad (SFT) base block: sftBaseM rows x baseD cols.
  uint32_t blockNum = BLOCK / B16;  // elements per 32B block at 16-bit dtype
  uint32_t dimDAlign = CeilCommon(dimD, blockNum) * blockNum;

  // step0: bound the single-shot M by the matmul singleM, capped at 256.
  sftSingleM = BASE_LEN_256;
  if (singleM < sftSingleM) {
    sftSingleM = singleM;
  }

  // step1: decide whether D must be split (aligned D exceeds the budget).
  if (dimDAlign > tensorSize) {
    // template
    sftSingleM = BASE_LEN_128;
    baseD = BASE_LEN_64;
    needSplitD = true;
    // NOTE: the original also assigned sftBaseM = BASE_LEN_128 here; that
    // store was dead because step2 unconditionally recomputes sftBaseM.
  } else {
    baseD = dimDAlign;
    needSplitD = false;
  }

  // step2: rows = budget / columns, clamped by sftSingleM.
  sftBaseM = tensorSize / baseD;
  if (sftBaseM > sftSingleM) {
    sftBaseM = sftSingleM;
  }
}

void FlashAttentionScoreGradTilingUs1s2Bbn2::DecideBaseMND() {
  /*
  Func:
  0. decide SFT-BaseM
  1. decide VECTOR-BaseMN
  2. decide MATMUL-BaseMN
  3. decide SingleMN (shared between BMM and VECTOR)
  4. decide the D-axis split (limited by the new algorithm, the D split
     depends on Matmul-BaseM)
     eg: baseMmm = 128, baseD <= 64
  */
  // Init: defaults before any shape-specific adjustment.
  baseM = BASE_LEN_64;
  baseN = BASE_LEN_128;
  baseMmm = BASE_LEN_128;
  baseNmm = BASE_LEN_128;
  tensorSize = baseM * baseN;
  mmRatio = MM_RATIO_4;
  dimD = td_.opInfo.get_D();
  dimS2 = td_.opInfo.get_S2();
  dimS1 = td_.opInfo.get_S1();

  VectorBaseMNSplit();

  // Matmul special case (cacheline friendly): small D with long S2 favors a
  // wider matmul base block and a flatter vector block.
  if (dimD <= BASE_LEN_64 && dimS2 >= BASE_LEN_256) {
    baseMmm = BASE_LEN_128;
    baseNmm = BASE_LEN_256;
    baseM = BASE_LEN_32;
    baseN = BASE_LEN_256;
  } else {
    baseMmm = BASE_LEN_128;
    baseNmm = BASE_LEN_128;
  }

  // Some scenarios need adjusted S1/S2 ratios.
  s1Ratio = 1;
  s2Ratio = mmRatio;
  NMDStrategy();
  singleM = baseMmm * s1Ratio;
  singleN = baseNmm * s2Ratio;

  SFTBaseMDSplit();
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::CheckAttenMaskShape() {
  // Validate the atten-mask shape against the active compress mode.
  // Logs now consistently pass context_ (the rest of this file, and the last
  // branch of the original, already did) instead of ad-hoc tag strings.
  if (attenMaskCompressMode == NO_COMPRESS_MODE) {
    // Uncompressed, non-TND: the mask must cover at least s1 * s2 elements.
    bool invalid =
        td_.opInfo.get_existAttenMask() != 0 && tmpData_.layout != static_cast<uint32_t>(InputLayout::TND) &&
        ((int64_t)attenMaskS1Size * (int64_t)attenMaskS2Size < (int64_t)tmpData_.s1 * (int64_t)tmpData_.s2);
    if (invalid) {
      OPS_LOG_E(context_, "atten mask shape [%u,%u] is invalid.", attenMaskS1Size, attenMaskS2Size);
      return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
  }

  if (attenMaskCompressMode == PREFIX_COMPRESS_MODE) {
    // Prefix compress expects the fixed [3072, 2048] mask.
    if (attenMaskS1Size != PREFIX_COMPRESS_S1_SIZE || attenMaskS2Size != ATTEN_MASK_COMPRESS_LIMIT) {
      OPS_LOG_E(context_,
                "atten mask shape for prefix compress mode is invalid, try setting it to [3072, 2048].");
      return ge::GRAPH_FAILED;
    }
    return ge::GRAPH_SUCCESS;
  }

  // Causal/band compress modes expect a square [2048, 2048] mask.
  if (attenMaskS1Size != attenMaskS2Size) {
    OPS_LOG_E(context_, "atten mask shape is not square.");
    return ge::GRAPH_FAILED;
  }
  if (attenMaskS2Size != ATTEN_MASK_COMPRESS_LIMIT) {
    OPS_LOG_E(context_, "atten mask shape is invalid, try setting it to [2048, 2048].");
    return ge::GRAPH_FAILED;
  }

  return ge::GRAPH_SUCCESS;
}

bool FlashAttentionScoreGradTilingUs1s2Bbn2::CheckForDichotomy(std::vector<int64_t>& nums, uint32_t x, uint32_t m) {
  // Greedily pack consecutive weights into groups whose sum stays within the
  // cap x, and report whether the resulting group count fits into m cores.
  int64_t groupSum = 0;
  uint32_t groupCnt = 1;
  for (const int64_t weight : nums) {
    if (groupSum + weight > x) {
      // Current weight would overflow the cap: open a new group with it.
      ++groupCnt;
      groupSum = weight;
    } else {
      groupSum += weight;
    }
  }
  return groupCnt <= m;
}

int64_t FlashAttentionScoreGradTilingUs1s2Bbn2::GetSplitArrayMinMaxSum(std::vector<int64_t>& s1s2WeightNums,
                                                                       uint32_t coreNum) {
  // Binary-search the smallest achievable "largest group sum" when the weight
  // array is split into at most coreNum groups of consecutive elements.
  int64_t low = 0;   // lower bound: the single largest weight
  int64_t high = 0;  // upper bound: the sum of all weights
  for (const int64_t weight : s1s2WeightNums) {
    high += weight;
    low = std::max(low, weight);
  }
  while (low < high) {
    const int64_t mid = (low + high) >> 1;
    // NOTE(review): CheckForDichotomy takes uint32_t; mid is narrowed here —
    // assumes the total weight fits in 32 bits, confirm for large shapes.
    if (CheckForDichotomy(s1s2WeightNums, mid, coreNum)) {
      high = mid;      // feasible: try a tighter cap
    } else {
      low = mid + 1;   // infeasible: the cap must grow
    }
  }
  return low;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::DoBlockTiling() {
  // Distribute the B * N2 outer iterations over the AI cores. "Former" cores
  // each process formerCoreProcessNNums items; the remaining cores process one
  // item fewer, so the whole range is covered exactly.
  int64_t nNums = td_.opInfo.get_B() * td_.opInfo.get_N2();
  int64_t formerCoreProcessNNums = CeilCommon(nNums, aicoreParams_.blockDim);
  int64_t remainCoreNum = formerCoreProcessNNums * aicoreParams_.blockDim - nNums;
  td_.opInfo.set_usedCoreNum(aicoreParams_.blockDim);
  if (td_.opInfo.get_layout() == static_cast<uint32_t>(InputLayout::TND)) {
    if (nNums < static_cast<int64_t>(aicoreParams_.blockDim)) {
      // Report the number of cores actually used; it is the valid index range
      // for the bN2idxStarts / bN2idxEnds lookups.
      td_.opInfo.set_usedCoreNum(nNums);
    }
  }
  td_.opInfo.set_formerCoreNum(aicoreParams_.blockDim - remainCoreNum);
  td_.opInfo.set_formerCoreProcessNNum(formerCoreProcessNNums);
  td_.opInfo.set_remainCoreProcessNNum(formerCoreProcessNNums - 1);
  td_.opInfo.set_castUsedCoreNum(aicoreParams_.blockDim);
  if (nNums < static_cast<int64_t>(aicoreParams_.blockDim)) {
    td_.opInfo.set_castUsedCoreNum(nNums);
  }

  // Core-splitting strategy for the TND layout under deterministic
  // computation: balance per-core work using the per-(B, N2) s1*s2 weights.
  if (td_.opInfo.get_layout() == static_cast<uint32_t>(InputLayout::TND)) {
    if (nNums <= static_cast<int64_t>(aicoreParams_.blockDim)) {
      // At most one (B, N2) pair per core: assign one-to-one.
      for (uint32_t i = 0; i < nNums; i++) {
        tmpData_.bN2idxStarts[i] = i;
        tmpData_.bN2idxEnds[i] = i;
      }
    } else {
      // More pairs than cores: first find the minimal achievable maximum
      // per-core weight sum, then greedily assign consecutive pairs.
      int64_t minMaxSum = GetSplitArrayMinMaxSum(tmpData_.s1s2Weight, aicoreParams_.blockDim);
      OPS_LOG_D(context_->GetNodeName(),
                "FlashAttentionScoreGradTilingUs1s2Bbn2 splite B*N2 minMaxSum: %ld.", minMaxSum);
      int64_t tempSum = 0;
      uint32_t coreIdxBegin = 0;
      uint32_t coreIdxEnd = 0;
      uint32_t tempSumIdx = 0;
      tmpData_.bN2idxStarts[coreIdxBegin++] = 0;
      for (uint32_t i = 0; i < tmpData_.s1s2Weight.size(); i++) {
          tempSum += tmpData_.s1s2Weight[i];
          // Close the current core's range when its weight budget is exceeded,
          // or when the remaining pairs only just cover the remaining cores.
          if ((tempSum > minMaxSum) || (aicoreParams_.blockDim - (tempSumIdx + 1) >= (tmpData_.s1s2Weight.size() - i))) {
              tempSumIdx++;
              tempSum = tmpData_.s1s2Weight[i];
              tmpData_.bN2idxEnds[coreIdxEnd++] = i - 1;
              tmpData_.bN2idxStarts[coreIdxBegin++] = i;
              if (aicoreParams_.blockDim - tempSumIdx >= (tmpData_.s1s2Weight.size() - i)) {
                // Exactly enough cores for the leftover pairs: map the rest
                // one pair per core and finish.
                coreIdxBegin--;
                for (uint32_t k = 0; k < tmpData_.s1s2Weight.size() - i; k++) {
                  tmpData_.bN2idxStarts[coreIdxBegin++] = i + k;
                  tmpData_.bN2idxEnds[coreIdxEnd++] = i + k;
                }
                coreIdxEnd--;
              break;
              }
          }
      }
      // The last core's range always ends at the final pair.
      tmpData_.bN2idxEnds[coreIdxEnd] = tmpData_.s1s2Weight.size() - 1;
    }
    td_.tndSplitCoreParams.set_bN2idxStarts(tmpData_.bN2idxStarts);
    td_.tndSplitCoreParams.set_bN2idxEnds(tmpData_.bN2idxEnds);
  }

  // Debug info: dump the final per-core [start, end] ranges.
  OPS_LOG_D(context_->GetNodeName(), "FlashAttentionScoreGradTilingUs1s2Bbn2 splite B*N2 usedCoreNum: %u.",
            td_.opInfo.get_usedCoreNum());
  for (uint32_t i = 0; i < td_.opInfo.get_usedCoreNum(); i++) {
    OPS_LOG_D(context_->GetNodeName(),
              "FlashAttentionScoreGradTilingUs1s2Bbn2 splite B*N2 CoreIdx: %u begin: %u end: %u.",
              i, td_.tndSplitCoreParams.get_bN2idxStarts()[i], td_.tndSplitCoreParams.get_bN2idxEnds()[i]);
  }

  return ge::GRAPH_SUCCESS;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::DoOpTiling() {
  OPS_LOG_I(context_, "FlashAttentionScoreGradTilingUs1s2Bbn2 DoTiling start");

  // Choose the vector/matmul base block sizes first; everything below
  // derives from them.
  DecideBaseMND();

  // Main kernel UB buffer lengths: tensorSize elements at 4B each.
  const auto mainUbLen = tensorSize * B32;
  td_.opInfo.set_inputBufferLen(mainUbLen);
  td_.opInfo.set_helpBufferLen(mainUbLen);

  // Publish the split-core parameters into the tiling data.
  td_.splitCoreParams.set_SFTBaseM(sftBaseM);
  td_.splitCoreParams.set_SFTSingleM(sftSingleM);
  td_.splitCoreParams.set_dInner(baseD);
  td_.splitCoreParams.set_baseM(baseM);
  td_.splitCoreParams.set_baseN(baseN);
  td_.splitCoreParams.set_singleM(singleM);
  td_.splitCoreParams.set_singleN(singleN);
  td_.splitCoreParams.set_s1OuterOuter(CeilCommon(td_.opInfo.get_S1(), singleM));
  td_.splitCoreParams.set_s2OuterOuter(CeilCommon(td_.opInfo.get_S2(), singleN));

  // Run the dependent tiling stages in order, stopping at the first failure.
  ge::graphStatus ret = DoBlockTiling();
  if (ret == ge::GRAPH_SUCCESS) {
    ret = DoCastTiling();
  }
  if (ret == ge::GRAPH_SUCCESS) {
    ret = DropoutTiling();
  }
  if (ret != ge::GRAPH_SUCCESS) {
    return ret;
  }

  // Reject shapes whose atten_mask does not satisfy the compress-mode rules.
  if (CheckAttenMaskShape() != ge::GRAPH_SUCCESS) {
    OPS_LOG_I(context_, "FAG Us1s2Bbn2 with shape b[%u] n2[%u] g[%u] s1[%u] s2[%u] d[%u] preToken[%d] nextToken[%d]!",
              td_.opInfo.get_B(), td_.opInfo.get_N2(), td_.opInfo.get_G(), td_.opInfo.get_S1(), td_.opInfo.get_S2(),
              td_.opInfo.get_D(), td_.opInfo.get_preTokens(), td_.opInfo.get_nextTokens());
    return ge::GRAPH_FAILED;
  }
  return ge::GRAPH_SUCCESS;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::DoLibApiTiling() {
  // simpleSoftMax tiling: the dst shape equals the src shape [baseM, baseN].
  const auto softmaxShape = ge::Shape({baseM, baseN});
  const auto softmaxTmpLen = tensorSize * tmpData_.dataTypeSize;  // buffer size in the UB-resident dtype
  AscendC::SoftMaxTilingFunc(softmaxShape, sizeof(float), softmaxTmpLen, td_.softmaxTilingData);

  // softmaxGrad tiling: the fp32 path works on [sftBaseM, 8] and the fp16
  // path on [sftBaseM, 16] (BLOCK / dataTypeSize).
  const auto gradShape = ge::Shape({sftBaseM, BLOCK / tmpData_.dataTypeSize});
  const auto gradTmpLen = 2 * tensorSize * tmpData_.dataTypeSize;  // 64KB buffer in the UB-resident dtype
  AscendC::SoftMaxGradTilingFunc(gradShape, tmpData_.dataTypeSize, gradTmpLen, td_.softmaxGradTilingData, true);

  // Matmul tiling for bmm1 / bmm3.1 / bmm4; stop at the first failure.
  ge::graphStatus ret = SetBmm1TilingData(baseMmm, baseNmm, aicoreParams_.l1Size);
  if (ret == ge::GRAPH_SUCCESS) {
    ret = SetBmm31TilingData(baseMmm, aicoreParams_.l1Size);
  }
  if (ret == ge::GRAPH_SUCCESS) {
    ret = SetBmm4TilingData(baseMmm, baseNmm, aicoreParams_.l1Size);
  }
  return ret;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::DoCastTiling() {
  // Tiling for the post-processing cast that converts dq/dk/dv from the fp32
  // workspace back to the output dtype, split across the cast cores.

  // Total query element count (== dq output element count).
  uint64_t allNumQuery = static_cast<uint64_t>(td_.opInfo.get_B()) * td_.opInfo.get_N2()
      * td_.opInfo.get_G() * td_.opInfo.get_S1() * td_.opInfo.get_D();
  // TND: use the real accumulated token count t1 instead of B * S1.
  if (td_.opInfo.get_layout() == static_cast<uint32_t>(InputLayout::TND)) {
    allNumQuery =
      static_cast<uint64_t>(tmpData_.t1) * td_.opInfo.get_N2() * td_.opInfo.get_G() * td_.opInfo.get_D();
  }

  // Total key/value element count (== dk/dv output element count).
  uint64_t allNumKv = static_cast<uint64_t>(td_.opInfo.get_B()) * td_.opInfo.get_N2()
      * td_.opInfo.get_S2() * td_.opInfo.get_D();
  // TND: use the real accumulated token count t2 instead of B * S2.
  if (td_.opInfo.get_layout() == static_cast<uint32_t>(InputLayout::TND)) {
    allNumKv =
      static_cast<uint64_t>(tmpData_.t2) * td_.opInfo.get_N2() * 1 * td_.opInfo.get_D();
  }

  uint32_t used_core_num = td_.opInfo.get_castUsedCoreNum();

  // Per-pass UB base size, rounded down to a multiple of 256B.
  uint32_t postUbBaseSize = (aicoreParams_.ubSize) / POST_COEX_NODE / BUFFER_NUM / BASE_LEN_256 * BASE_LEN_256;
  uint32_t qPostBaseNum = postUbBaseSize / B16;
  // NOTE(review): allNumQuery/allNumKv are narrowed to uint32 here; assumes
  // the element counts fit in 32 bits -- confirm for very large shapes.
  uint32_t qPostBlockTotal = allNumQuery;
  // Bug fix: round up with GM_ALIGN - 1. The previous rounding term
  // (BASE_LEN_256 - 1 == 255) under-aligned the 512B-aligned size whenever
  // qPostBlockTotal % GM_ALIGN fell in (255, 512); this now matches the kv
  // path below.
  uint64_t qSizeAlign = (qPostBlockTotal + GM_ALIGN - 1) / GM_ALIGN * GM_ALIGN * B16;
  uint32_t qPostTailNumTmp = qPostBlockTotal % qPostBaseNum;
  uint32_t qPostTailNum = qPostTailNumTmp == 0 ? qPostBaseNum : qPostTailNumTmp;
  uint32_t qPostBlockOuterTotal = (qPostBlockTotal + qPostBaseNum - 1) / qPostBaseNum;
  uint32_t qPostBlockFactor = (qPostBlockOuterTotal + used_core_num - 1) / used_core_num;

  uint32_t kvPostBaseNum = postUbBaseSize / B16;
  uint32_t kvPostBlockTotal = allNumKv;
  uint64_t kvSizeAlign = (kvPostBlockTotal + GM_ALIGN - 1) / GM_ALIGN * GM_ALIGN * B16;
  uint32_t kvPostTailNumTmp = kvPostBlockTotal % kvPostBaseNum;
  uint32_t kvPostTailNum = kvPostTailNumTmp == 0 ? kvPostBaseNum : kvPostTailNumTmp;
  uint32_t kvPostBlockOuterTotal = (kvPostBlockTotal + kvPostBaseNum - 1) / kvPostBaseNum;
  uint32_t kvPostBlockFactor = (kvPostBlockOuterTotal + used_core_num - 1) / used_core_num;

  td_.postTilingData.set_coreNum(used_core_num);
  td_.postTilingData.set_scaleValue(td_.opInfo.get_scaleValue());
  td_.postTilingData.set_postUbBaseSize(postUbBaseSize);
  td_.postTilingData.set_qPostBlockFactor(qPostBlockFactor);
  td_.postTilingData.set_qPostBlockTotal(qPostBlockTotal);
  td_.postTilingData.set_qPostBaseNum(qPostBaseNum);
  td_.postTilingData.set_qPostTailNum(qPostTailNum);
  td_.postTilingData.set_qSizeAlign(qSizeAlign);

  td_.postTilingData.set_kvPostBlockFactor(kvPostBlockFactor);
  td_.postTilingData.set_kvPostBlockTotal(kvPostBlockTotal);
  td_.postTilingData.set_kvPostBaseNum(kvPostBaseNum);
  td_.postTilingData.set_kvPostTailNum(kvPostTailNum);
  td_.postTilingData.set_kvSizeAlign(kvSizeAlign);

  // fp32 accumulation workspaces for dq/dk/dv, 512B-aligned.
  td_.opInfo.set_dqWorkspaceLen((allNumQuery * B32 + GM_ALIGN - 1) / GM_ALIGN * GM_ALIGN);
  td_.opInfo.set_dkWorkspaceLen((allNumKv * B32 + GM_ALIGN - 1) / GM_ALIGN * GM_ALIGN);
  td_.opInfo.set_dvWorkspaceLen((allNumKv * B32 + GM_ALIGN - 1) / GM_ALIGN * GM_ALIGN);

  return ge::GRAPH_SUCCESS;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::DropoutTiling() {
  // UB budget: 120KB fp16 tensor + 60KB u8 tensor + 8KB mask + 512B help,
  // which stays below the 192KB UB. singleUBProcessNum is the maximum number
  // of fp16 elements one UB pass handles and must be divisible by 128.
  const uint32_t castBufferLen = 60 * 1024;
  const uint32_t outputBufferLen = 30 * 1024;
  const uint32_t inputBufferLen = 4 * 1024;
  const uint32_t singleUBProcessNum = castBufferLen / NUMBER_2;

  // Round v up to the next multiple of align.
  const auto alignUp = [](uint64_t v, uint64_t align) { return (v + align - 1) / align * align; };

  uint64_t dropoutWorkspaceLen = 0;
  const bool dropoutEnabled = td_.opInfo.get_keepProb() < 1;

  if (td_.opInfo.get_layout() == static_cast<uint32_t>(InputLayout::TND)) {
    // TND: a dropout workspace is required when any per-batch S2 is not a
    // multiple of baseN (unused tail entries are zero and trivially pass).
    if (dropoutEnabled) {
      bool allDivisible = true;
      for (size_t i = 0; i < UNPAD_S1S2_BATCH_MAX_SIZE; i++) {
        if (tmpData_.actualSeqKvlen[i] % baseN != 0) {
          allDivisible = false;
          break;
        }
      }
      if (!allDivisible) {
        uint64_t len = static_cast<uint64_t>(tmpData_.sumS1S2Product);
        len *= static_cast<uint64_t>(td_.opInfo.get_N2());
        len *= static_cast<uint64_t>(td_.opInfo.get_G());
        dropoutWorkspaceLen = alignUp(alignUp(len, BLOCK), GM_ALIGN);
      }
    }
  } else if (dropoutEnabled && td_.opInfo.get_S2() % baseN != 0) {
    // Dense layouts: the full B * N2 * G * S1 * S2 mask is staged in GM.
    uint64_t len = static_cast<uint64_t>(td_.opInfo.get_B()) * td_.opInfo.get_N2() *
                   td_.opInfo.get_G() * td_.opInfo.get_S1() * td_.opInfo.get_S2();
    dropoutWorkspaceLen = alignUp(alignUp(len, BLOCK), GM_ALIGN);
  }

  td_.dropoutCastParams.set_dropoutWorkspaceLen(dropoutWorkspaceLen);
  td_.dropoutCastParams.set_castBufferLen(castBufferLen);
  td_.dropoutCastParams.set_outputBufferLen(outputBufferLen);
  td_.dropoutCastParams.set_inputBufferLen(inputBufferLen);
  td_.dropoutCastParams.set_singleUBProcessNum(singleUBProcessNum);

  return ge::GRAPH_SUCCESS;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::GetWorkspaceSize() {
  // All sizes and start addresses handed over through tiling are byte counts.
  auto blockdim = CalcTschBlockDim(aicoreParams_.blockDim, aicoreParams_.aicNum, aicoreParams_.blockDim);
  OPS_ERR_IF(blockdim == 0,
             OPS_REPORT_VECTOR_INNER_ERR("FlashAttentionScoreGradTilingUs1s2Bbn2",
                                              "blockdim is 0, aicNum is %ld, aivNum is %ld.", aicoreParams_.aicNum,
                                              aicoreParams_.blockDim),
             return ge::GRAPH_FAILED);
  context_->SetBlockDim(blockdim);

  // System-reserved workspace.
  uint64_t sysLen = WORK_SPACE_BASE_CAL;
  // Soft-sync area for the dropout stage.
  uint64_t syncDropoutLen = AlignData(3200, GM_ALIGN);
  // Soft-sync area for the zeroing (atomic-clear) stage.
  uint64_t syncAtomicLen = AlignData(3200, GM_ALIGN);
  // Soft-sync area for the cast stage.
  uint64_t syncCastLen = AlignData(3200, GM_ALIGN);
  // Dropout output staging area (0 when not needed, see DropoutTiling).
  uint64_t dropoutWorkspaceLen = td_.dropoutCastParams.get_dropoutWorkspaceLen();

  // bmm1/bmm2 output buffers (one singleM x singleN tile each, per core).
  uint64_t mm1WorkspaceLen = singleM * singleN * tmpData_.dataTypeSize;
  mm1WorkspaceLen = AlignData(mm1WorkspaceLen, GM_ALIGN);
  uint64_t mm2WorkspaceLen = singleM * singleN * tmpData_.dataTypeSize;
  mm2WorkspaceLen = AlignData(mm2WorkspaceLen, GM_ALIGN);

  // bmm4/bmm3.1/bmm3.2 input buffers.
  uint64_t typeSize = tmpData_.queryType == ge::DT_FLOAT ? B32 : B16;
  uint64_t space = singleM * singleN * typeSize;
  space = AlignData(space, GM_ALIGN);
  // The matmul outputs are produced asynchronously; double (ping-pong)
  // buffering prevents a buffer from being overwritten while still in use.
  uint64_t mm4InputWorkspaceLen = space + space;
  uint64_t mm3InputWorkspaceLen = space + space;

  // fp32 accumulation workspace for the dq cast.
  uint64_t dqWorkspaceLen = td_.opInfo.get_dqWorkspaceLen();
  // fp32 accumulation workspace for the dk cast.
  uint64_t dkWorkspaceLen = td_.opInfo.get_dkWorkspaceLen();
  // fp32 accumulation workspace for the dv cast.
  uint64_t dvWorkspaceLen = td_.opInfo.get_dvWorkspaceLen();

  // Set the global workspace size; the regions are laid out in this order.
  size_t* workspaces = context_->GetWorkspaceSizes(1);
  workspaces[0] = sysLen + syncDropoutLen + syncAtomicLen + syncCastLen;
  workspaces[0] += dropoutWorkspaceLen;
  workspaces[0] += (mm1WorkspaceLen + mm2WorkspaceLen) * td_.opInfo.get_usedCoreNum();
  workspaces[0] += (mm4InputWorkspaceLen + mm3InputWorkspaceLen) * td_.opInfo.get_usedCoreNum();
  workspaces[0] += dqWorkspaceLen + dkWorkspaceLen + dvWorkspaceLen;

  // Publish the per-region lengths into the tiling data.
  td_.opInfo.set_mm1WorkspaceLen(mm1WorkspaceLen);
  td_.opInfo.set_mm2WorkspaceLen(mm2WorkspaceLen);
  td_.opInfo.set_mm4InputWorkspaceLen(mm4InputWorkspaceLen);
  td_.opInfo.set_mm3InputWorkspaceLen(mm3InputWorkspaceLen);
  td_.opInfo.set_syncDropoutLen(syncDropoutLen);
  td_.opInfo.set_syncAtomicLen(syncAtomicLen);
  td_.opInfo.set_syncCastLen(syncCastLen);

  // Offsets of the dq/dk/dv regions (relative, excluding the system-reserved
  // area), recomputed from the stored tiling fields set above.
  uint64_t workspaceOffsets = td_.opInfo.get_syncDropoutLen() + td_.opInfo.get_syncAtomicLen() + td_.opInfo.get_syncCastLen();
  workspaceOffsets += td_.dropoutCastParams.get_dropoutWorkspaceLen() + (td_.opInfo.get_mm1WorkspaceLen() +
                      td_.opInfo.get_mm2WorkspaceLen()) * td_.opInfo.get_usedCoreNum();
  workspaceOffsets += (td_.opInfo.get_mm4InputWorkspaceLen() + td_.opInfo.get_mm3InputWorkspaceLen()) * td_.opInfo.get_usedCoreNum();

  td_.postTilingData.set_dqWorkSpaceOffset(workspaceOffsets);

  workspaceOffsets = workspaceOffsets + td_.opInfo.get_dqWorkspaceLen();
  td_.postTilingData.set_dkWorkSpaceOffset(workspaceOffsets);

  workspaceOffsets = workspaceOffsets + td_.opInfo.get_dkWorkspaceLen();
  td_.postTilingData.set_dvWorkSpaceOffset(workspaceOffsets);

  return ge::GRAPH_SUCCESS;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::PostTiling() {
  // Serialize the tiling struct into the context's raw tiling buffer and
  // record the number of bytes written.
  auto* rawTiling = context_->GetRawTilingData();
  td_.SaveToBuffer(rawTiling->GetData(), rawTiling->GetCapacity());
  rawTiling->SetDataSize(td_.GetDataSize());

  return ge::GRAPH_SUCCESS;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::GetLayoutInfo() {
  // Parse the input_layout string attribute into the InputLayout enum and
  // cache it in tmpData_. TND is accepted only when GetDeterministic() == 1.
  // Returns GRAPH_PARAM_INVALID for unsupported or missing layouts.
  auto* attrs = context_->GetAttrs();
  if (attrs == nullptr) {
    OPS_LOG_I(context_, "FlashAttentionScoreGradTilingUs1s2Bbn2 attrs is null");
    return ge::GRAPH_PARAM_INVALID;
  }
  const char* inputLayout = attrs->GetAttrPointer<char>(INPUT_LAYOUT);
  // Guard: calling strcmp on a null attribute pointer is undefined behavior.
  if (inputLayout == nullptr) {
    OPS_LOG_I(context_, "FlashAttentionScoreGradTilingUs1s2Bbn2 input layout attr is null");
    return ge::GRAPH_PARAM_INVALID;
  }
  if (strcmp(inputLayout, BSH_STR) == 0) {
    td_.opInfo.set_layout(static_cast<uint32_t>(InputLayout::BSH));
  } else if (strcmp(inputLayout, SBH_STR) == 0) {
    td_.opInfo.set_layout(static_cast<uint32_t>(InputLayout::SBH));
  } else if (strcmp(inputLayout, BNSD_STR) == 0) {
    td_.opInfo.set_layout(static_cast<uint32_t>(InputLayout::BNSD));
  } else if (strcmp(inputLayout, BSND_STR) == 0) {
    td_.opInfo.set_layout(static_cast<uint32_t>(InputLayout::BSND));
  } else if (strcmp(inputLayout, TND_STR) == 0 && context_->GetDeterministic() == 1) {
    td_.opInfo.set_layout(static_cast<uint32_t>(InputLayout::TND));
  } else {
    OPS_LOG_I(context_, "FlashAttentionScoreGradTilingUs1s2Bbn2 unsupported layout");
    return ge::GRAPH_PARAM_INVALID;
  }
  tmpData_.layout = td_.opInfo.get_layout();
  return ge::GRAPH_SUCCESS;
}

// Derive B / N2 / G / S1 / S2 / D from the query/key shapes according to the
// layout previously stored in td_.opInfo, then validate them. dimN1 is the
// head_num attribute (N1); G = N1 / N2.
ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::SetBaseInfo(
    const gert::Shape& queryShape, const gert::Shape& keyShape, uint32_t dimN1) {
  if (td_.opInfo.get_layout() == static_cast<uint32_t>(InputLayout::BNSD)) {
    // BNSD: query [B, N1, S1, D], key [B, N2, S2, D].
    if (queryShape.GetDimNum() != DIMS_FOUR || keyShape.GetDimNum() != DIMS_FOUR) {
      return ge::GRAPH_PARAM_INVALID;
    }
    td_.opInfo.set_B(queryShape.GetDim(DIM_0));
    td_.opInfo.set_G(queryShape.GetDim(DIM_1) / keyShape.GetDim(DIM_1));
    td_.opInfo.set_N2(keyShape.GetDim(DIM_1));
    td_.opInfo.set_S1(queryShape.GetDim(DIM_2));
    td_.opInfo.set_S2(keyShape.GetDim(DIM_2));
    td_.opInfo.set_D(queryShape.GetDim(DIM_3));
  } else if (td_.opInfo.get_layout() == static_cast<uint32_t>(InputLayout::BSND)) {
    // BSND: query [B, S1, N1, D], key [B, S2, N2, D].
    if (queryShape.GetDimNum() != DIMS_FOUR || keyShape.GetDimNum() != DIMS_FOUR) {
      return ge::GRAPH_PARAM_INVALID;
    }
    td_.opInfo.set_B(queryShape.GetDim(DIM_0));
    td_.opInfo.set_G(queryShape.GetDim(DIM_2) / keyShape.GetDim(DIM_2));
    td_.opInfo.set_N2(keyShape.GetDim(DIM_2));
    td_.opInfo.set_S1(queryShape.GetDim(DIM_1));
    td_.opInfo.set_S2(keyShape.GetDim(DIM_1));
    td_.opInfo.set_D(queryShape.GetDim(DIM_3));
  } else if (td_.opInfo.get_layout() == static_cast<uint32_t>(InputLayout::BSH)) {
    // BSH: query [B, S1, N1*D], key [B, S2, N2*D]; D and N2 are recovered
    // from the fused H dimension via dimN1.
    if (queryShape.GetDimNum() != DIMS_THREE || keyShape.GetDimNum() != DIMS_THREE) {
      return ge::GRAPH_PARAM_INVALID;
    }
    uint32_t dimD = queryShape.GetDim(DIM_2) / dimN1;
    uint32_t dimN2 = keyShape.GetDim(DIM_2) / dimD;
    td_.opInfo.set_B(queryShape.GetDim(DIM_0));
    td_.opInfo.set_G(dimN1 / dimN2);
    td_.opInfo.set_N2(dimN2);
    td_.opInfo.set_S1(queryShape.GetDim(DIM_1));
    td_.opInfo.set_S2(keyShape.GetDim(DIM_1));
    td_.opInfo.set_D(dimD);
  } else if (td_.opInfo.get_layout() == static_cast<uint32_t>(InputLayout::SBH)) {
    // SBH: query [S1, B, N1*D], key [S2, B, N2*D].
    // NOTE(review): unlike the other layouts there is no GetDimNum() check
    // here -- confirm whether SBH shapes are validated upstream.
    uint32_t dimD = queryShape.GetDim(DIM_2) / dimN1;
    uint32_t dimN2 = keyShape.GetDim(DIM_2) / dimD;
    td_.opInfo.set_B(queryShape.GetDim(DIM_1));
    td_.opInfo.set_G(dimN1 / dimN2);
    td_.opInfo.set_N2(dimN2);
    td_.opInfo.set_S1(queryShape.GetDim(DIM_0));
    td_.opInfo.set_S2(keyShape.GetDim(DIM_0));
    td_.opInfo.set_D(dimD);
  } else if (td_.opInfo.get_layout() == static_cast<uint32_t>(InputLayout::TND)) {
    // TND: per-batch sequence lengths come from the optional cumulative
    // actual_seq_qlen / actual_seq_kvlen tensors.
    auto actualSeqQlenTensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_Q_LEN);
    auto actualSeqKvlenTensor = context_->GetOptionalInputTensor(ACTUAL_SEQ_KV_LEN);
    if (actualSeqQlenTensor == nullptr || actualSeqKvlenTensor == nullptr) {
      OPS_LOG_D("actualSeqQlenTensor or actualSeqKvlenTensor is nullptr");
      return ge::GRAPH_PARAM_INVALID;
    }

    const size_t seqQShapeSize = actualSeqQlenTensor->GetShapeSize();
    const size_t kvSeqShapeSize = actualSeqKvlenTensor->GetShapeSize();
    if (seqQShapeSize != kvSeqShapeSize) {
      OPS_LOG_D("actualSeqQlenTensor shapeSize is not equal actualSeqKvlenTensor");
      return ge::GRAPH_PARAM_INVALID;
    }

    // The unpad batch count is limited to UNPAD_S1S2_BATCH_MAX_SIZE (256).
    if (seqQShapeSize > UNPAD_S1S2_BATCH_MAX_SIZE) {
      OPS_LOG_D("inputLayout = TND", "actualSeqQlenTensor shape size %zu is large than %d", seqQShapeSize,
                UNPAD_S1S2_BATCH_MAX_SIZE);
      return ge::GRAPH_PARAM_INVALID;
    }

    // Zero the full arrays so unused tail entries are well-defined.
    for (size_t i = 0; i < UNPAD_S1S2_BATCH_MAX_SIZE; i++) {
      tmpData_.actualSeqQlen[i] = 0;
      tmpData_.actualSeqKvlen[i] = 0;
    }

    std::vector<int64_t> actualSeqQlen;
    std::vector<int64_t> actualSeqKvlen;
    // The input tensors hold cumulative lengths; convert to per-batch lengths
    // by differencing adjacent entries.
    const int64_t* qValue = actualSeqQlenTensor->GetData<int64_t>();
    const int64_t* kvValue = actualSeqKvlenTensor->GetData<int64_t>();
    uint32_t tempN2 = keyShape.GetDim(DIM_1);
    for (size_t i = 0; i < seqQShapeSize; i++) {
      int64_t qSeqLen = (i == 0 ? qValue[i] : (qValue[i] - qValue[i - 1]));
      int64_t kvSeqLen = (i == 0 ? kvValue[i] : (kvValue[i] - kvValue[i - 1]));
      actualSeqQlen.push_back(qSeqLen);
      actualSeqKvlen.push_back(kvSeqLen);
      int64_t s1s2Product = actualSeqQlen[i] * actualSeqKvlen[i];
      tmpData_.sumS1S2Product += s1s2Product;

      // Flatten the s1*s2 weights over B * N2: iterate N2 first, then B.
      for (size_t n2 = 0; n2 < tempN2; n2++) {
        tmpData_.s1s2Weight.push_back(s1s2Product);
      }
    }

    std::copy(actualSeqQlen.begin(), actualSeqQlen.end(), tmpData_.actualSeqQlen);
    std::copy(actualSeqKvlen.begin(), actualSeqKvlen.end(), tmpData_.actualSeqKvlen);

    // t1/t2: total token counts (last cumulative entry).
    tmpData_.t1 = qValue[seqQShapeSize - 1];
    tmpData_.t2 = kvValue[seqQShapeSize - 1];
    td_.opInfo.set_B(seqQShapeSize);

    // query [t1, n1, d]   kv [t2, n2, d]   dy [t1, n1, d]
    // S1/S2 are set to the largest per-batch lengths.
    td_.opInfo.set_G(queryShape.GetDim(DIM_1) / keyShape.GetDim(DIM_1));
    td_.opInfo.set_N2(keyShape.GetDim(DIM_1));
    td_.opInfo.set_S1(*std::max_element(actualSeqQlen.begin(), actualSeqQlen.end()));
    td_.opInfo.set_S2(*std::max_element(actualSeqKvlen.begin(), actualSeqKvlen.end()));
    td_.opInfo.set_D(queryShape.GetDim(DIM_2));
  } else {
    return ge::GRAPH_PARAM_INVALID;
  }
  // Mirror the derived dimensions into the scratch data used by later stages.
  tmpData_.b = td_.opInfo.get_B();
  tmpData_.n2 = td_.opInfo.get_N2();
  tmpData_.d = td_.opInfo.get_D();
  tmpData_.s1 = td_.opInfo.get_S1();
  tmpData_.s2 = td_.opInfo.get_S2();
  return td_.opInfo.get_layout() == static_cast<uint32_t>(InputLayout::TND)
             ? CheckTndShapeValid(context_, tmpData_.t1, dimN1, tmpData_.d)
             : CheckShapeValid(context_, tmpData_.b, dimN1, tmpData_.s1, tmpData_.d);
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::GetBaseShapeInfo() {
  // Read the query/key storage shapes and the head_num attribute (== N1),
  // then derive the base dimensions via SetBaseInfo.
  auto* queryShapePtr = context_->GetInputShape(QUERY);
  auto* keyShapePtr = context_->GetInputShape(KEY);
  auto* attrs = context_->GetAttrs();
  // Guard against missing inputs/attrs before dereferencing.
  if (queryShapePtr == nullptr || keyShapePtr == nullptr || attrs == nullptr) {
    return ge::GRAPH_PARAM_INVALID;
  }
  const int* headNumPtr = attrs->GetAttrPointer<int>(HEAD_NUM);
  if (headNumPtr == nullptr) {
    return ge::GRAPH_PARAM_INVALID;
  }
  const gert::Shape& queryShape = queryShapePtr->GetStorageShape();
  const gert::Shape& keyShape = keyShapePtr->GetStorageShape();
  uint32_t dimN1 = *headNumPtr;  // headNum is the same as N1
  if (SetBaseInfo(queryShape, keyShape, dimN1) != ge::GRAPH_SUCCESS) {
    return ge::GRAPH_PARAM_INVALID;
  }

  return ge::GRAPH_SUCCESS;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::SetBmm1TilingData(uint32_t sOut, uint32_t sFla,
                                                                          uint32_t l1SizeRemain) {
  // bmm1: A and B live in GM, C[S1, S2] lands in UB (VECCALC). The K-axis
  // origin strides depend on the layout:
  //   BSH/BSND/TND: ka = G * N2 * D (= N1 * D), kb = N2 * D
  //   SBH:          ka = B * N1 * D, kb = B * N2 * D
  //   BNSD:         ka = kb = D
  const uint32_t layoutVal = td_.opInfo.get_layout();
  const bool packedNd = layoutVal == static_cast<uint32_t>(InputLayout::BSH) ||
                        layoutVal == static_cast<uint32_t>(InputLayout::BSND) ||
                        layoutVal == static_cast<uint32_t>(InputLayout::TND);
  const bool sbhLayout = layoutVal == static_cast<uint32_t>(InputLayout::SBH);

  uint32_t ka = td_.opInfo.get_D();
  uint32_t kb = td_.opInfo.get_D();
  const uint32_t n = td_.opInfo.get_S2();
  if (packedNd) {
    ka = td_.opInfo.get_G() * td_.opInfo.get_N2() * td_.opInfo.get_D();
    kb = td_.opInfo.get_N2() * td_.opInfo.get_D();
  } else if (sbhLayout) {
    ka = td_.opInfo.get_B() * td_.opInfo.get_G() * td_.opInfo.get_N2() * td_.opInfo.get_D();
    kb = td_.opInfo.get_B() * td_.opInfo.get_N2() * td_.opInfo.get_D();
  }

  // Matmul dtypes follow the value tensor dtype; bf16/fp32 inputs and the
  // fp16 high-precision branch all select an fp32 output.
  auto inputType = matmul_tiling::DataType::DT_FLOAT16;
  auto outputType = matmul_tiling::DataType::DT_FLOAT16;
  const auto valueType = static_cast<uint32_t>(context_->GetInputDesc(VALUE)->GetDataType());
  if (valueType == ge::DT_BF16) {
    inputType = matmul_tiling::DataType::DT_BF16;
    outputType = matmul_tiling::DataType::DT_FLOAT;
  } else if (valueType == ge::DT_FLOAT) {
    inputType = matmul_tiling::DataType::DT_FLOAT;
    outputType = matmul_tiling::DataType::DT_FLOAT;
  } else if (valueType == ge::DT_FLOAT16 && tmpData_.branch == BRANCH_FP16_HIGH_PRECISION) {
    outputType = matmul_tiling::DataType::DT_FLOAT;
  }

  matmul_tiling::MatmulApiTiling mm1;
  mm1.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, inputType, false);
  mm1.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, inputType, true);  // B transposed
  mm1.SetCType(matmul_tiling::TPosition::VECCALC, matmul_tiling::CubeFormat::ND, outputType);
  mm1.SetShape(singleM, singleN, td_.opInfo.get_D());
  mm1.SetOrgShape(singleM, n, ka, kb);
  mm1.SetBias(false);
  mm1.SetBufferSpace(l1SizeRemain, aicoreParams_.l0cSize);
  mm1.SetFixSplit(sOut, sFla, -1);
  if (mm1.GetTiling(td_.mm1TilingData) == -1) {
    OPS_LOG_I(context_, "FlashAttentionScoreGradTilingUs1s2Bbn2 mm1 GetTiling Failed");
    return ge::GRAPH_PARAM_INVALID;
  }
  td_.mm1TilingData.set_shareMode(0);
  td_.mm1TilingData.set_shareL1Size(l1SizeRemain);
  td_.mm1TilingData.set_shareL0CSize(aicoreParams_.l0cSize);

  return ge::GRAPH_SUCCESS;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::SetBmm31TilingData(uint32_t sOut,
                                                                           uint32_t l1SizeRemain) {
  // bmm3.1: A[S1, S2] (in UB-produced workspace) times B (in GM) yields C (in
  // GM), with k = S2 on both operands and m = S1. The origin N depends on the
  // layout:
  //   BSH/BSND/TND: n = N2 * D
  //   SBH:          n = B * N2 * D
  //   BNSD:         n = D
  // NOTE(review): sOut is not consumed here (no SetFixSplit call); it appears
  // to be kept for signature parity with the other SetBmm*TilingData methods.
  const uint32_t layoutVal = td_.opInfo.get_layout();
  const bool packedNd = layoutVal == static_cast<uint32_t>(InputLayout::BSH) ||
                        layoutVal == static_cast<uint32_t>(InputLayout::BSND) ||
                        layoutVal == static_cast<uint32_t>(InputLayout::TND);
  const bool sbhLayout = layoutVal == static_cast<uint32_t>(InputLayout::SBH);

  const uint32_t ka = td_.opInfo.get_S2();
  const uint32_t kb = td_.opInfo.get_S2();
  const uint32_t m = td_.opInfo.get_S1();
  uint32_t n = td_.opInfo.get_D();
  if (packedNd) {
    n = td_.opInfo.get_D() * td_.opInfo.get_N2();  // differing N per head not supported
  } else if (sbhLayout) {
    n = td_.opInfo.get_D() * td_.opInfo.get_N2() * td_.opInfo.get_B();  // differing N per head not supported
  }

  // Matmul dtypes follow the key tensor dtype; bf16/fp32 inputs and the fp16
  // high-precision branch all select an fp32 output.
  auto inputAType = matmul_tiling::DataType::DT_FLOAT16;
  auto inputBType = matmul_tiling::DataType::DT_FLOAT16;
  auto outputType = matmul_tiling::DataType::DT_FLOAT16;
  const auto keyType = static_cast<uint32_t>(context_->GetInputDesc(KEY)->GetDataType());
  if (keyType == ge::DT_BF16) {
    inputAType = matmul_tiling::DataType::DT_BF16;
    inputBType = matmul_tiling::DataType::DT_BF16;
    outputType = matmul_tiling::DataType::DT_FLOAT;
  } else if (keyType == ge::DT_FLOAT) {
    inputAType = matmul_tiling::DataType::DT_FLOAT;
    inputBType = matmul_tiling::DataType::DT_FLOAT;
    outputType = matmul_tiling::DataType::DT_FLOAT;
  } else if (keyType == ge::DT_FLOAT16 && tmpData_.branch == BRANCH_FP16_HIGH_PRECISION) {
    inputAType = matmul_tiling::DataType::DT_FLOAT16;
    inputBType = matmul_tiling::DataType::DT_FLOAT16;
    outputType = matmul_tiling::DataType::DT_FLOAT;
  }

  matmul_tiling::MatmulApiTiling mm31;
  mm31.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, inputAType, false);
  mm31.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, inputBType, false);
  mm31.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, outputType);
  mm31.SetShape(singleM, td_.opInfo.get_D(), singleN);
  mm31.SetOrgShape(m, n, ka, kb);
  mm31.SetBias(false);
  mm31.SetBufferSpace(l1SizeRemain, aicoreParams_.l0cSize);
  if (mm31.GetTiling(td_.mm31TilingData) == -1) {
    OPS_LOG_I(context_, "FlashAttentionScoreGradTilingUs1s2Bbn2 mm31 GetTiling Failed");
    return ge::GRAPH_PARAM_INVALID;
  }
  td_.mm31TilingData.set_shareMode(0);
  td_.mm31TilingData.set_shareL1Size(l1SizeRemain);
  td_.mm31TilingData.set_shareL0CSize(aicoreParams_.l0cSize);

  return ge::GRAPH_SUCCESS;
}

ge::graphStatus FlashAttentionScoreGradTilingUs1s2Bbn2::SetBmm4TilingData(uint32_t sOut, uint32_t sFla,
                                                                          uint32_t l1SizeRemain) {
  // bmm4: A[S1, S2] (transposed, in UB-produced workspace) times B (in GM)
  // yields C (in GM), with k = S1 on both operands and m = S2. The origin N
  // depends on the layout:
  //   BSH/BSND/TND: n = D * N2 * G (= N1 * D)
  //   SBH:          n = D * N2 * G * B
  //   BNSD:         n = D
  // NOTE(review): sOut and sFla are not consumed here (no SetFixSplit call);
  // kept for signature parity with SetBmm1TilingData.
  const uint32_t layoutVal = td_.opInfo.get_layout();
  const bool packedNd = layoutVal == static_cast<uint32_t>(InputLayout::BSH) ||
                        layoutVal == static_cast<uint32_t>(InputLayout::BSND) ||
                        layoutVal == static_cast<uint32_t>(InputLayout::TND);
  const bool sbhLayout = layoutVal == static_cast<uint32_t>(InputLayout::SBH);

  const uint32_t ka = td_.opInfo.get_S1();
  const uint32_t kb = td_.opInfo.get_S1();
  const uint32_t m = td_.opInfo.get_S2();
  uint32_t n = td_.opInfo.get_D();
  if (packedNd) {
    n = td_.opInfo.get_D() * td_.opInfo.get_N2() * td_.opInfo.get_G();  // differing N per head not supported
  } else if (sbhLayout) {
    n = td_.opInfo.get_D() * td_.opInfo.get_N2() * td_.opInfo.get_G() * td_.opInfo.get_B();  // differing N per head not supported
  }

  // Matmul dtypes follow the key tensor dtype; bf16/fp32 inputs and the fp16
  // high-precision branch all select an fp32 output.
  auto inputAType = matmul_tiling::DataType::DT_FLOAT16;
  auto inputBType = matmul_tiling::DataType::DT_FLOAT16;
  auto outputType = matmul_tiling::DataType::DT_FLOAT16;
  const auto keyType = static_cast<uint32_t>(context_->GetInputDesc(KEY)->GetDataType());
  if (keyType == ge::DT_BF16) {
    inputAType = matmul_tiling::DataType::DT_BF16;
    inputBType = matmul_tiling::DataType::DT_BF16;
    outputType = matmul_tiling::DataType::DT_FLOAT;
  } else if (keyType == ge::DT_FLOAT) {
    inputAType = matmul_tiling::DataType::DT_FLOAT;
    inputBType = matmul_tiling::DataType::DT_FLOAT;
    outputType = matmul_tiling::DataType::DT_FLOAT;
  } else if (keyType == ge::DT_FLOAT16 && tmpData_.branch == BRANCH_FP16_HIGH_PRECISION) {
    inputAType = matmul_tiling::DataType::DT_FLOAT16;
    inputBType = matmul_tiling::DataType::DT_FLOAT16;
    outputType = matmul_tiling::DataType::DT_FLOAT;
  }

  matmul_tiling::MatmulApiTiling mm4;
  mm4.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, inputAType, true);  // A transposed
  mm4.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, inputBType, false);
  mm4.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, outputType);
  mm4.SetShape(singleN, td_.opInfo.get_D(), singleM);
  mm4.SetOrgShape(m, n, ka, kb);
  mm4.SetBias(false);
  mm4.SetBufferSpace(l1SizeRemain, aicoreParams_.l0cSize);
  if (mm4.GetTiling(td_.mm4TilingData) == -1) {
    OPS_LOG_I(context_, "FlashAttentionScoreGradTilingUs1s2Bbn2 mm4 GetTiling Failed");
    return ge::GRAPH_PARAM_INVALID;
  }
  td_.mm4TilingData.set_shareMode(0);
  td_.mm4TilingData.set_shareL1Size(l1SizeRemain);
  td_.mm4TilingData.set_shareL0CSize(aicoreParams_.l0cSize);

  return ge::GRAPH_SUCCESS;
}

// Register both tiling templates for FlashAttentionScoreGrad with their
// priority values (15000 vs 1000) -- presumably the registry uses these to
// order template selection; confirm against the registry's semantics.
REGISTER_TILING_TEMPLATE("FlashAttentionScoreGrad", FlashAttentionScoreGradTilingUs1s2Bbn2, 15000);
REGISTER_TILING_TEMPLATE("FlashAttentionScoreGrad", FlashAttentionScoreGradTilingDeterministic, 1000);
}  // namespace optiling
